
{--
 * This package implements the first compiler pass after lexical analysis.
 * It fixes the list of definitions by incorporating document comments into
 * definitions, and joining separate equations of functions that come out as
 * separate 'FunDcl' definitions from parsing.

 * The function 'fixdefs' will also be used from other modules for fixing
 * sub definitions.
 -}


package frege.compiler.passes.Fix where

import frege.Prelude hiding(<+>, break)

-- import Data.TreeMap(insertkv)

import Lib.PP (break, group, msgdoc, text, <+>)
-- import frege.compiler.Utilities   as U()

import  Compiler.enums.TokenID(VARID, CONID)
import  Compiler.enums.CaseKind

import  Compiler.types.SNames(Simple, With1)
-- import  Compiler.types.QNames
import  Compiler.types.Tokens(baseTokenAt, KeyTk, Token)
import  Compiler.types.Positions(Position, Positioned)
import  Compiler.types.SourceDefinitions
import  Compiler.types.Global as G

import  Compiler.common.Errors as E()
import  Compiler.common.Binders
import  Compiler.common.Mangle (noClashIdent)

import Compiler.classes.Nice
import Compiler.instances.NiceExprS


{--
    Entry point of this pass: runs 'fixdefs' over the source definitions held
    in the global state and stores the result back.

    The returned pair is the label @"definitions"@ and the number of
    definitions that were present before fixing.
 -}
pass :: StG (String, Int)
pass = do
    g <- getST
    let before = g.sub.sourcedefs
    fixed <- fixdefs before
    changeST Global.{sub <- _.{sourcedefs = fixed}}
    return ("definitions", length before)


{--
    Transform a list of definitions in three steps:

    1. 'unDoc' attaches documentation comments to the definition that follows,
    2. 'unlet' desugars pattern bindings (one definition may become several),
    3. 'funJoin' merges adjacent equations of the same function.

    Finally 'checkUniq' reports names that are bound more than once.
 -}
fixdefs :: [DefinitionS] -> StG [DefinitionS]
fixdefs defs = do
    documented <- unDoc defs
    expanded   <- mapM unlet documented
    fs         <- funJoin (concat expanded)
    checkUniq [ name | dcl@FunDcl {lhs} <- fs, name <- funbinding dcl ]
    return fs


{--
    Report every name that occurs again later in the list.

    For each name, an error is flagged at the position of the first later
    occurrence with the same value; the message mentions the line of the
    earlier one. All names are checked, even after an error was found.
 -}
checkUniq [] = return ()
checkUniq (name:as) = do
    let clashes = [ later | later <- as, later.value == name.value ]
    case clashes of
        dup:_ -> E.error (Pos dup dup) (msgdoc ("redefinition of  `"
                                                ++ name.value
                                                ++ "` introduced line " ++ show name.line))
        []    -> return ()
    checkUniq as

{--
    Desugar pattern bindings.

    A pattern binding like
    > (a, b) = expr
    arrives here as a pseudo function binding and is rewritten to
    > gen$nnn = expr
    > a = case gen$nnn of (a,b) -> a
    > b = case gen$nnn of (a,b) -> b
    where @nnn@ is the source offset of the first token of the pattern.

    If the pattern contains exactly one variable, no auxiliary binding is made:
    > Just a = expr
    becomes
    > a = case expr of Just a -> a

    A pattern without any variable is flagged as an error.
    Definitions that are not pattern bindings are returned unchanged.
 -}
unlet :: DefinitionS -> StG [DefinitionS]
unlet f
    | FunDcl{vis, lhs, pats, expr, doc} <- f,
      patbinding f = do
        let pos       = getpos lhs
            genName   = "gen$" ++ show pos.first.offset
            genVar    = Vbl (Simple pos.first.{tokid=VARID, value=genName})
            -- gen$nnn = expr
            genDef    = f.{lhs=genVar, pats=[], positions=[genVar.name.id]}
            -- the pattern as it was written in the source
            pat       = fold App lhs pats
            -- the variables a, b, ... bound by the pattern
            vars      = exvars pat
            -- with a single variable the original expression is scrutinised directly
            single    = length vars == 1
            scrutinee = if single then expr else genVar
            -- v = case scrutinee of pat -> v
            selector (var@Vbl{}) = f.{lhs=var, pats=[],
                    positions = [var.name.id],
                    expr=Case CNormal scrutinee [CAlt{pat, ex=var}]}
            selector _ = error "mkFun"
            selectors = map selector vars
        when (null vars) do
            g <- getST
            E.error pos (msgdoc "Left hand side of pattern binding "
                        <+> text (nicer pat g)
                        <+> msgdoc " contains no variables.")
        pure (if single then selectors else genDef:selectors)
    | FunDcl{lhs, pats} <- f,
      Nothing <- funbinding f,                  -- must be !pat = or ?pat =
      [pat]   <- pats,
      p       <- f.{lhs=pat, pats=[]},          -- pat = ...
      patbinding p = do
        desugared <- unlet p                    -- desugar the inner pattern binding
        case desugared of
            []    -> error "Can't happen as unlet doesn't return []"
            hd:tl -> return (hd.{lhs=f.lhs, pats=[hd.lhs]}:tl)
    | otherwise = pure [f]



{--
 * Apply a series of documentation comments to the definition that follows.
 *
 * Adjacent 'DocDcl's are joined with a blank line and prepended to the
 * documentation of the next definition. Imports and fixity declarations
 * cannot be documented; a warning is issued and the text is dropped.
 * A warning is also issued when documentation is followed by nothing.
 *
 * Definitions that have sub definitions get those fixed with 'fixdefs'.
 -}
unDoc :: [DefinitionS] -> StG [DefinitionS]
unDoc [] = stio []
unDoc (defs@(d:ds))
    | DocDcl {} <- d = apply doc rest >>= unDoc
    | d.{defs?}      = do
        ndefs  <- fixdefs d.defs
        others <- unDoc ds
        stio (d.{defs=ndefs} : others)
    | otherwise      = do
        others <- unDoc ds
        stio (d : others)
    where
        pos  = d.pos
        isDoc (DocDcl {}) = true
        isDoc _           = false
        docs = takeWhile isDoc defs
        rest = dropWhile isDoc defs
        doc  = joined "\n\n" (map DefinitionS.text docs)
        -- warn at p that the documentation is dropped, and pass on the definitions unchanged
        ignoring p what remaining = do
            E.warn p (msgdoc (what ++ show pos ++ " ignored."))
            stio remaining
        apply :: String -> [DefinitionS] -> StG [DefinitionS]
        apply _   []              = do
            E.warn pos (msgdoc ("documentation at end of file"))
            stio []
        apply str (target:others) = case target of
            ImpDcl {pos=p} ->
                ignoring p "there is no point in documenting an import, documentation from line "
                         (target:others)
            FixDcl {pos=p} ->
                ignoring p "there is no point in documenting a fixity declaration, documentation from line "
                         (target:others)
            -- the new text goes in front of documentation that is already there
            def -> stio (def.{doc = Just (maybe str (\s -> str ++ "\n\n" ++ s) def.doc)} : others)

{--
 * Look for adjacent function definitions with the same name and join them
 * into one definition.
 *
 * A run of equations
 * > f p1 p2 = e1
 * > f q1 q2 = e2
 * becomes a single definition whose patterns are fresh variables and whose
 * body is a @case@ on the tuple of those variables, with one alternative
 * per original equation.
 *
 * Errors are flagged when
 * - an equation without patterns is followed by further equations for the same name,
 * - the equations differ in visibility,
 * - the equations differ in the number of patterns.
 * In each of these cases the first equation is kept as is.
 *
 * Definitions that are not function bindings are passed through unchanged.
 -}
funJoin :: [DefinitionS] -> StG [DefinitionS]
funJoin [] = return []
funJoin (defs@(d:ds))
    | FunDcl {lhs} <- d, Just name <- funbinding d 
    = do 
        joined <- joinFuns (Pos name name) (funs name)
        rest   <- funJoin  (next name)
        return (joined:rest)
    | otherwise = do
        rest   <- funJoin ds
        return (d:rest)
    where
        -- the leading run of equations for this name, and what follows it
        funs name = takeWhile (sameFun name) defs
        next name = dropWhile (sameFun name) defs
        sameFun name fundcl | Just n <- funbinding fundcl = n.value == name.value
        sameFun name _ = false
        joinFuns :: Position -> [DefinitionS] -> StG DefinitionS
        -- a single equation needs no case expression
        joinFuns pos [f] = return f.{positions=[pos.first]}
        joinFuns pos (fs@(f:_))
            | null f.pats = do
                    E.error pos (msgdoc "function binding without patterns must have only a single equation")
                    return f
            | (g:_) <- filter (\x -> DefinitionS.vis x != f.vis) fs = do
                    E.error (getpos g.lhs) (msgdoc ("the visibility of " ++ g.name ++
                                    " must match that of the equation in line " ++ show pos))
                    stio f
            | (g:_) <- filter (\x -> length (DefinitionS.pats x) != length f.pats) fs = do
                    -- the message now closes the parenthesis around the expected count
                    E.error (getpos g.lhs) (msgdoc ("number of patterns (" ++ show (length g.pats) ++
                                   ") must be the same as in previous equations (" ++
                                   show (length f.pats) ++ ")"))
                    return f
            | otherwise = stio result  -- all equations have same # of patterns and visibility
            where
                arity  = length f.pats

                result = f.{pats = newvars, expr = newexpr, doc = newdoc,
                            positions = mapMaybe funbinding fs}
                -- one fresh variable per argument, placed at the position of the first equation
                newvars = [ Vbl  (Simple pos.first.{tokid=VARID, value=noClashIdent i})  | i <- take arity allAsciiBinders]
                -- case (v1, ..., vn) of (p1, ..., pn) -> e; ...
                newexpr = Case CNormal (mkTuple Con pos newvars) alts 
                alts    = [ CAlt {pat=mkpTuple (getpos g.lhs) g.pats, ex = g.expr} |
                             (g::DefinitionS) <- fs ]
                -- documentation of all equations, separated by blank lines
                olddoc  = [ s | Just s <- map DefinitionS.doc fs ]
                newdoc  = if null olddoc then Nothing else Just (joined "\n\n" olddoc)
        joinFuns _ [] = error "fatal compiler error: joinFuns []"

--- create the constructor name for an n-tuple: @n-1@ commas in parentheses, e.g. @(,)@ for 2
tuple n = "(" ++ joined "" (take (n - 1) (repeat ",")) ++ ")"


{--
    Make the name of the @i@-tuple constructor.
    Both components of the 'With1' name are derived from the base token at @t@;
    the second one is turned into a 'CONID' token whose value is 'tuple' @i@.
 -}
tupleName i t = With1 base base.{tokid=CONID, value=tuple i}
    where base = baseTokenAt t


{--
    Build a tuple from a list of items using constructor function @con@.
    A one element list yields the element itself; otherwise the tuple
    constructor named by 'tupleName' is applied to all elements with 'App'.
 -}
mkTuple con (pos::Position) xs = case xs of
        [x] -> x
        _   -> fold App (con (tupleName n pos.first)) xs
    where !n = length xs


{--
    Build a tuple pattern from a list of patterns.
    A single pattern is returned as is; otherwise the tuple constructor
    named by 'tupleName' is wrapped in 'Con' and applied to all patterns.
 -}
mkpTuple (pos::Position) ps
    | [p] <- ps = p
    | otherwise = fold App (Con (tupleName n pos.first)) ps
    where !n = length ps
